Skip to content

Fix --do-sample / --enable-bucketing argparse type=bool bug#32

Open
vyom-hai wants to merge 1 commit intoaws-neuron:mainfrom
vyom-hai:fix/do-sample-argparse-type-bool
Open

Fix --do-sample / --enable-bucketing argparse type=bool bug#32
vyom-hai wants to merge 1 commit intoaws-neuron:mainfrom
vyom-hai:fix/do-sample-argparse-type-bool

Conversation

@vyom-hai
Copy link
Copy Markdown

@vyom-hai vyom-hai commented Apr 30, 2026

Summary

  • Fixes a silent argparse bug in main.py: --do-sample False parses as True, --enable-bucketing False parses as True.
  • Switches both flags to argparse.BooleanOptionalAction so the standard --flag / --no-flag idiom works.
  • Affected lines: main.py:113 (--do-sample) and main.py:157 (--enable-bucketing).

The bug

argparse(type=bool) is not a string→bool converter. The type= callable is invoked on the raw CLI argument string, and bool(non_empty_string) is always True in Python. So every textual "off" value the user might reasonably pass is silently coerced to True.

Empirically (with the unpatched main.py):

--do-sample ''       -> args.do_sample = False
--do-sample 'False'  -> args.do_sample = True   <-- silently wrong!
--do-sample 'false'  -> args.do_sample = True
--do-sample 'no'     -> args.do_sample = True
--do-sample '0'      -> args.do_sample = True
--do-sample 'True'   -> args.do_sample = True
--do-sample 'yes'    -> args.do_sample = True
no flag              -> args.do_sample = True   (default)

The only reachable "off" path is --do-sample '', which is awkward to type from most shells and obviously not what users intend. --no-do-sample is rejected outright because the action doesn't declare it. The same is true for --enable-bucketing.

Why it matters for the contest

This isn't a stylistic concern — it materially affects accuracy validation in evaluate_all:

  1. args.do_sample=True (the default, and what every reasonable string also resolves to) is propagated into OnDeviceSamplingConfig(do_sample=True, ...) at main.py:380-383.
  2. OnDeviceSamplingConfig is consumed at trace time by prepare_inference, so do_sample=True is baked into the compiled graph. There's no runtime override.
  3. Inside the validation harness (run_accuracy_checklogit_validationgenerate_fn), the call at main.py:604 is model.generate(..., do_sample=False, ...). That kwarg is silently ignored: when self.on_device_sampling=True, hf_adapter._sample uses outputs.tokens from the on-device multinomial sampler (hf_adapter.py:229-230), not the host-side do_sample argument.
  4. Net effect: every accuracy iteration falls into logit_validation's divergence-handling fallback path, which is slower and non-deterministic. Multiple teams have had to programmatically work around this by mutating args.do_sample = False before prepare_inference.

The fix converts the validation accuracy run into the deterministic greedy/argmax path it was always supposed to use, simply by allowing --no-do-sample to actually parse as False.

The fix

# Before
parser.add_argument("--do-sample", type=bool, default=True)
parser.add_argument("--enable-bucketing", type=bool, default=True)

# After
parser.add_argument(
    "--do-sample",
    action=argparse.BooleanOptionalAction,
    default=True,
    help="Enable multinomial sampling on the on-device sampler. "
         "Pass --no-do-sample to force greedy/argmax (required for "
         "deterministic logit_validation accuracy checks).",
)
parser.add_argument(
    "--enable-bucketing",
    action=argparse.BooleanOptionalAction,
    default=True,
    help="Enable bucketed compilation. Pass --no-enable-bucketing to disable.",
)

argparse.BooleanOptionalAction is the standard library idiom for paired --flag / --no-flag boolean options. It has been part of argparse since Python 3.9; the contest harness already requires Python ≥3.10 transitively via neuronx_distributed.

Test plan

  • Empirically verified that --no-do-sampleFalse, --do-sampleTrue, no-flag → True (default).
  • Empirically verified that --no-enable-bucketingFalse, --enable-bucketingTrue, no-flag → True (default).
  • No change to any other CLI flag, function body, or import (argparse was already imported at main.py:3).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant